"""
Mask R-CNN
Multi-GPU Support for Keras.

Copyright (c) 2017 Matterport, Inc.
Licensed under the MIT License (see LICENSE for details)
Written by Waleed Abdulla

Ideas and small code snippets from these sources:
https://github.com/fchollet/keras/issues/2436
https://medium.com/@kuza55/transparent-multi-gpu-training-on-tensorflow-with-keras-8b0016fd9012
https://github.com/avolkov1/keras_experiments/blob/master/keras_exp/multigpu/
https://github.com/fchollet/keras/blob/master/keras/utils/training_utils.py
"""

import tensorflow as tf
import tensorflow.keras.backend as K
import tensorflow.keras.layers as KL
import tensorflow.keras.models as KM

# import keras.backend as K
# import keras.layers as KL
# import keras.models as KM


def make_parallel(keras_model, gpu_count):
    """Build the merged outputs of a multi-GPU replica set of a Keras model.

    Every model input is split into ``gpu_count`` slices, the model is
    called once per GPU on its slice, and the per-GPU outputs are merged
    back together on the CPU.

    Args:
        keras_model: the input model to replicate on multiple GPUs
        gpu_count: the number of replicas to build
    Returns:
        A list with one merged tensor per model output, in the order of
        ``keras_model.output_names``. Outputs whose static shape is ``()``
        are averaged across replicas; all other outputs are concatenated
        along axis 0.
    """
    # Slice every input once, up front (tf.split's default axis is 0), so
    # each replica only consumes its own slice. The intent is to avoid
    # sending a copy of the full inputs to all GPUs.
    input_slices = {name: tf.split(x, gpu_count)
                    for name, x in zip(keras_model.input_names,
                                       keras_model.inputs)}

    def slice_selector(input_name, replica_index):
        """Return a Lambda function yielding one replica's input slice.

        The slice is resolved here, at creation time, instead of inside the
        lambda. A lambda that closes over the loop variables would look
        them up only when it is called, and would therefore return the
        last replica's slice if the layer is ever invoked again after the
        loops below have finished.
        """
        input_slice = input_slices[input_name][replica_index]
        return lambda s: input_slice

    output_names = keras_model.output_names
    # One bucket per model output; each bucket collects that output from
    # every replica.
    outputs_all = [[] for _ in keras_model.outputs]

    # Run the model's call() on each GPU to place the ops there
    for i in range(gpu_count):
        with tf.device('/gpu:%d' % i):
            with tf.name_scope('tower_%d' % i):
                # Run a slice of inputs through this replica
                inputs = [
                    KL.Lambda(slice_selector(name, i),
                              output_shape=lambda s: (None,) + s[1:])(tensor)
                    for name, tensor in zip(keras_model.input_names,
                                            keras_model.inputs)]
                # Create the model replica and get the outputs
                outputs = keras_model(inputs)
                # A single tensor is wrapped so the code below can always
                # iterate; a tuple of outputs is treated like a list.
                if not isinstance(outputs, (list, tuple)):
                    outputs = [outputs]
                # Save the outputs for merging back together later
                for index, output in enumerate(outputs):
                    outputs_all[index].append(output)

    # Merge outputs on CPU
    with tf.device('/cpu:0'):
        merged = []
        for outputs, name in zip(outputs_all, output_names):
            # Concatenate or average outputs?
            # Outputs usually have a batch dimension and we concatenate
            # across it. If they don't, then the output is likely a loss
            # or a metric value that gets averaged across the batch.
            # Keras expects losses and metrics to be scalars.
            if K.int_shape(outputs[0]) == ():
                # Average. The replica count is bound as a default argument
                # so the lambda does not depend on the loop variable
                # `outputs` after this iteration.
                m = KL.Lambda(
                    lambda o, count=len(outputs): tf.add_n(o) / count,
                    name=name)(outputs)
            else:
                # Concatenate
                m = KL.Concatenate(axis=0, name=name)(outputs)
            merged.append(m)
    return merged


class ParallelModel(KM.Model):
    """Subclasses the standard Keras Model and adds multi-GPU support.

    It works by creating a copy of the model on each GPU. Then it slices
    the inputs and sends a slice to each copy of the model, and then
    merges the outputs together and applies the loss on the combined
    outputs.
    """

    def __init__(self, keras_model, gpu_count):
        """Class constructor.

        keras_model: The Keras model to parallelize
        gpu_count: Number of GPUs. Must be > 1
        """
        merged_outputs = make_parallel(
            keras_model=keras_model, gpu_count=gpu_count)
        super(ParallelModel, self).__init__(inputs=keras_model.inputs,
                                            outputs=merged_outputs)
        # Kept so that loading/saving can be redirected to the model that
        # was passed in (see __getattribute__).
        self.inner_model = keras_model

    def __getattribute__(self, attrname):
        """Redirect loading and saving methods to the inner model. That's
        where the weights are stored.

        Any attribute whose name contains 'load' or 'save' is looked up on
        the inner model. `inner_model` is only assigned after the base
        class constructor has returned, so lookups made before that point
        fall back to this object's own attributes instead of failing with
        an AttributeError about `inner_model`.
        """
        if 'load' in attrname or 'save' in attrname:
            try:
                # Go through the base implementation directly; the name
                # 'inner_model' never matches the redirect rule above.
                inner_model = super(ParallelModel, self).__getattribute__(
                    'inner_model')
            except AttributeError:
                # Still inside __init__: there is no inner model yet.
                pass
            else:
                return getattr(inner_model, attrname)
        return super(ParallelModel, self).__getattribute__(attrname)

    def summary(self, *args, **kwargs):
        """Override summary() to display summaries of both the wrapper
        and the inner model."""
        super(ParallelModel, self).summary(*args, **kwargs)
        self.inner_model.summary(*args, **kwargs)


if __name__ == "__main__":
    # Testing code below. It creates a simple model to train on MNIST and
    # tries to run it on 2 GPUs. It saves the graph so it can be viewed
    # in TensorBoard. Run it as:
    #
    # python3 parallel_model.py

    import os
    import numpy as np

    # The model below is built with tensorflow.keras layers (KL / KM), so
    # the dataset, generator, optimizer and callback are taken from the
    # same tf.keras package rather than from the standalone `keras`
    # package; objects from the two packages are not interchangeable.
    mnist = tf.keras.datasets.mnist
    ImageDataGenerator = tf.keras.preprocessing.image.ImageDataGenerator

    GPU_COUNT = 2

    # Root directory of the project
    ROOT_DIR = os.path.abspath("../")

    # Directory to save logs and trained model
    MODEL_DIR = os.path.join(ROOT_DIR, "logs")

    def build_model(x_train, num_classes):
        """Build a small convolutional classifier.

        x_train: training images; only `x_train.shape[1:]` is used, as the
            input shape of the model.
        num_classes: number of units of the final softmax layer.
        """
        # Reset default graph. Keras leaves old ops in the graph,
        # which are ignored for execution but clutter graph
        # visualization in TensorBoard.
        tf.reset_default_graph()

        inputs = KL.Input(shape=x_train.shape[1:], name="input_image")
        x = KL.Conv2D(32, (3, 3), activation='relu', padding="same",
                      name="conv1")(inputs)
        x = KL.Conv2D(64, (3, 3), activation='relu', padding="same",
                      name="conv2")(x)
        x = KL.MaxPooling2D(pool_size=(2, 2), name="pool1")(x)
        x = KL.Flatten(name="flat1")(x)
        x = KL.Dense(128, activation='relu', name="dense1")(x)
        x = KL.Dense(num_classes, activation='softmax', name="dense2")(x)

        # Pass the name by keyword so it cannot be mistaken for another
        # positional constructor argument.
        return KM.Model(inputs, x, name="digit_classifier_model")

    # Load MNIST Data
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    # Add a trailing channel axis and scale pixel values by 1/255.
    x_train = np.expand_dims(x_train, -1).astype('float32') / 255
    x_test = np.expand_dims(x_test, -1).astype('float32') / 255

    print('x_train shape:', x_train.shape)
    print('x_test shape:', x_test.shape)

    # Build data generator and model
    datagen = ImageDataGenerator()
    model = build_model(x_train, 10)

    # Add multi-GPU support.
    model = ParallelModel(model, GPU_COUNT)

    optimizer = tf.keras.optimizers.SGD(lr=0.01, momentum=0.9, clipnorm=5.0)

    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=optimizer, metrics=['accuracy'])

    model.summary()

    # Train
    model.fit_generator(
        datagen.flow(x_train, y_train, batch_size=64),
        steps_per_epoch=50, epochs=10, verbose=1,
        validation_data=(x_test, y_test),
        callbacks=[tf.keras.callbacks.TensorBoard(log_dir=MODEL_DIR,
                                                  write_graph=True)]
    )
